import matplotlib.pyplot as plt # type: ignore
import numpy as np # type: ignore
import pandas as pd # type: ignore
from matplotlib.ticker import PercentFormatter
from matplotlib.ticker import MultipleLocator

def plot_hist_colormaps_diff(freqs=(4000, 6000), case='_st'):
    """Plot histograms of the slice-vs-box displacement-amplitude ratio.

    For every frequency in ``freqs`` (one subplot column each) and every
    displacement direction Z, X, Y (one subplot row each), the per-pixel
    ratio ``20*log10(slice / box)`` in dB is histogrammed. Bar heights are
    the fraction of non-NaN ratios falling in each bin. Box pixels whose
    amplitude is below 1e-5 mm are treated as NaN and therefore excluded.

    Reads ``comsol_raw/box_colormap.csv`` and
    ``comsol_raw/slice_colormap<case>.csv`` (first 8 lines skipped, first
    two columns used as the index), saves the figure to
    ``manufigs/hist_colormap_diff<case>.pdf``, shows it, then closes it.

    Parameters
    ----------
    freqs : sequence of int or float
        Frequencies whose ``... @ freq=<f>`` columns are plotted.
    case : str
        Suffix selecting the slice CSV and naming the output PDF.
    """
    box_df = pd.read_csv('comsol_raw/box_colormap.csv', skiprows=8, index_col=[0, 1])
    slice_df = pd.read_csv('comsol_raw/slice_colormap' + case + '.csv', skiprows=8, index_col=[0, 1])

    # squeeze=False keeps axarr 2-D, so axarr[j, i] also works for one frequency.
    fig, axarr = plt.subplots(3, len(freqs), sharex='col', figsize=[4.5, 4.5], squeeze=False)
    bins = np.arange(-4, 2, 0.5)

    for i, freq in enumerate(freqs):
        print(freq)
        for j, direction in enumerate(['Z', 'X', 'Y']):
            print(j, direction)
            column = f'solid.uAmp{direction} (mm) @ freq={freq}'
            # Writable float copies: np.asarray() may hand back a view of the
            # DataFrame's buffer (read-only under pandas Copy-on-Write), and
            # box_z is masked in place below.
            box_z = box_df.loc[:, column].to_numpy(dtype=float, copy=True)
            slice_z = slice_df.loc[:, column].to_numpy(dtype=float, copy=True)

            print(np.count_nonzero(~np.isnan(box_z)))
            # Drop near-zero box amplitudes: 1e-5 mm = 0.01 um = 10 nm.
            box_z[box_z < 1e-5] = np.nan
            print(np.count_nonzero(~np.isnan(box_z)))

            z = 20 * np.log10(slice_z / box_z)
            z = z[~np.isnan(z)]

            print('max:', np.nanmax(np.abs(box_z)))

            ax = axarr[j, i]
            ax.hist(z, weights=np.ones(len(z)) / len(z), bins=bins, color='gray')
            if i == 0:
                ax.set_ylabel('pixel count')
            if j == 2:
                ax.set_xlabel('dB re. full-length')

        for ax in axarr[:, i]:
            # Vertical reference line at 0 dB (slice equals box).
            ax.plot([0, 0], [0, 1], linestyle='-', color='black', linewidth=1)
            ax.set_ylim(0, 1)
            ax.yaxis.set_major_formatter(PercentFormatter(1))
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            ax.set_xlim(-4.2, 2.2)
            ax.xaxis.set_major_locator(MultipleLocator(2))
            ax.xaxis.set_major_formatter('{x:.0f}')
            ax.xaxis.set_minor_locator(MultipleLocator(1))
            # Only the first column keeps a y-axis; the rest share its scale.
            if i > 0:
                ax.spines['left'].set_visible(False)
                ax.set_yticks([])

    fig.tight_layout()
    fig.subplots_adjust(hspace=0.4)
    fig.savefig('manufigs/hist_colormap_diff' + case + '.pdf', transparent=True)
    plt.show()
    plt.close(fig)

def _first_non_nan(values):
    """Return ``(index, value)`` of the first non-NaN entry of *values*."""
    valid = np.flatnonzero(~np.isnan(values))
    if valid.size == 0:
        raise ValueError('no non-NaN pixel available as phase reference')
    k = int(valid[0])
    return k, values[k]


def _wrap_phase(phase, direction):
    """Shift phases by one full turn into the window used for *direction*.

    For 'Y', values above 1 rad are moved down by 2*pi. For the other
    directions, values below -1 rad are moved up by 2*pi. NaNs pass through.
    """
    if direction == 'Y':
        return np.where(phase > 1, phase - 2 * np.pi, phase)
    return np.where(phase < -1, phase + 2 * np.pi, phase)


def plot_hist_colormaps_diff_phase(freqs=(4000, 6000), case='_st'):
    """Plot histograms of the box-minus-slice displacement-phase difference.

    For every frequency in ``freqs`` (one subplot column each) and every
    direction Z, X, Y (one subplot row each):

    1. Box pixels with phase below 1e-5 are set to NaN.
    2. Both phase maps are referenced to their first non-NaN Z-direction
       value for that frequency.
    3. The maps are wrapped by one turn (see ``_wrap_phase``).
    4. ``(box - slice) / (2*pi)``, i.e. the difference in cycles, is
       histogrammed. Bar heights are fractions of the non-NaN pixels.

    Reads ``comsol_raw/box_colormap_phase.csv`` and
    ``comsol_raw/slice_colormap_phase<case>.csv`` (first 8 lines skipped,
    first two columns used as the index), saves the figure to
    ``manufigs/hist_colormap_diff_phase<case>.pdf``, shows it, then closes it.

    Parameters
    ----------
    freqs : sequence of int or float
        Frequencies whose ``... @ freq=<f>`` columns are plotted.
    case : str
        Suffix selecting the slice CSV and naming the output PDF.

    Raises
    ------
    ValueError
        If a Z-direction phase map has no non-NaN pixel to use as reference.
    """
    box_df = pd.read_csv('comsol_raw/box_colormap_phase.csv', skiprows=8, index_col=[0, 1])
    slice_df = pd.read_csv('comsol_raw/slice_colormap_phase' + case + '.csv', skiprows=8, index_col=[0, 1])

    # squeeze=False keeps axarr 2-D, so axarr[j, i] also works for one frequency.
    fig, axarr = plt.subplots(3, len(freqs), sharex='col', figsize=[4.5, 4.5], squeeze=False)
    # One wide catch-all bin [-1/8, -1/16), then 1/64-cycle bins up to 3/16.
    # (1/8 appears only once, so there is no zero-width bin.)
    bins = np.concatenate(([-1 / 8], np.arange(-4, 13) / 64))

    for i, freq in enumerate(freqs):
        for j, direction in enumerate(['Z', 'X', 'Y']):
            print(j, direction)
            column = f'solid.uPhase{direction} (rad) @ freq={freq}'
            # Writable float copies: np.asarray() may return a (possibly
            # read-only) view of the DataFrame, and box_z is masked in place.
            box_z = box_df.loc[:, column].to_numpy(dtype=float, copy=True)
            slice_z = slice_df.loc[:, column].to_numpy(dtype=float, copy=True)

            # NOTE(review): this threshold mirrors the amplitude plot but is
            # applied to phase in rad, so it discards every box pixel with
            # phase < 1e-5 rad (all negative phases included). Kept as-is;
            # confirm this is the intended mask.
            box_z[box_z < 1e-5] = np.nan

            if j == 0:
                # The Z-direction reference phases are reused for X and Y.
                # NOTE(review): slice_z is not masked, so its first non-NaN
                # pixel may differ from the box's; confirm that is intended.
                k, box_shift = _first_non_nan(box_z)
                print(k, box_shift)
                k, slice_shift = _first_non_nan(slice_z)
                print(k, slice_shift)

            box_z = _wrap_phase(box_z - box_shift, direction)
            slice_z = _wrap_phase(slice_z - slice_shift, direction)

            z = (box_z - slice_z) / 2 / np.pi
            z = z[~np.isnan(z)]

            ax = axarr[j, i]
            ax.hist(z, weights=np.ones(len(z)) / len(z), color='gray', bins=bins)
            if i == 0:
                ax.set_ylabel('pixel count')
            if j == 2:
                ax.set_xlabel('diff. re. full-length (cycle)')

        for ax in axarr[:, i]:
            # Vertical reference line at zero phase difference.
            ax.plot([0, 0], [0, 1], linestyle='-', color='black', linewidth=1)
            ax.yaxis.set_major_formatter(PercentFormatter(1))
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            # Only left/right limits here. A third positional argument would
            # be `emit` (keyword-only in recent Matplotlib), not a limit.
            ax.set_xlim(-1 / 16 - 0.01, 3 / 32 + 0.01)
            ax.set_ylim(0, 1)
            ax.set_xticks([-1/16, -1/32, 0, 1/32, 1/16, 3/32])
            ax.set_xticklabels(['-1/16', '', '0', '', '1/16', ''])
            # Only the first column keeps a y-axis; the rest share its scale.
            if i > 0:
                ax.spines['left'].set_visible(False)
                ax.set_yticks([])

    fig.tight_layout()
    fig.subplots_adjust(hspace=0.4)
    fig.savefig('manufigs/hist_colormap_diff_phase' + case + '.pdf', transparent=True)
    plt.show()
    plt.close(fig)

# Render both figure sets only when run as a script, so importing this module
# does not read the CSVs, write PDFs or block on plt.show().
if __name__ == "__main__":
    plot_hist_colormaps_diff()
    plot_hist_colormaps_diff_phase()


